{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
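    "#### 1. Project Overview\n",
    "This lab builds on the previous one: it uses the same dataset to build a model that predicts the forest cover type. Whereas the previous lab focused mainly on data preprocessing and the model-building process, this lab focuses on using the Spark DataFrame API."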
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
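    "##### 1.1 Project Goal\n",
    "The goal of this project is to predict the forest cover type from the given forest cover dataset (which includes the target variable). The training data come from the US Forest Service (USFS) Resource Information System database and the US Geological Survey (USGS). The target variable, forest cover type, takes one of the following seven values:\n",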
    "+ 1 - Spruce/Fir\n",
    "+ 2 - Lodgepole Pine\n",
    "+ 3 - Ponderosa Pine\n",
    "+ 4 - Cottonwood/Willow\n",
    "+ 5 - Aspen\n",
    "+ 6 - Douglas-fir\n",
    "+ 7 - Krummholz"
   ]
  },
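  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When model predictions are read back later, it can help to map the numeric `Cover_Type` labels to the class names listed above. A minimal pure-Python sketch; `COVER_TYPES` and `cover_type_name` are hypothetical helpers, not part of the original lab:\n",
    "\n",
    "```python\n",
    "# Hypothetical lookup table built from the seven classes listed above\n",
    "COVER_TYPES = {\n",
    "    1: 'Spruce/Fir',\n",
    "    2: 'Lodgepole Pine',\n",
    "    3: 'Ponderosa Pine',\n",
    "    4: 'Cottonwood/Willow',\n",
    "    5: 'Aspen',\n",
    "    6: 'Douglas-fir',\n",
    "    7: 'Krummholz',\n",
    "}\n",
    "\n",
    "\n",
    "def cover_type_name(label):\n",
    "    # Labels become floats (e.g. 5.0) once CSV values are converted, so cast to int first\n",
    "    return COVER_TYPES[int(label)]\n",
    "\n",
    "\n",
    "print(cover_type_name(5.0))  # Aspen\n",
    "```"
   ]
  },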
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
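    "##### 1.2 Data Overview\n",
    "+ Training set: `data/CoverType/train.csv`\n",
    "+ Test set: `data/CoverType/test.csv`\n",
    "\n",
    "The training set (`train.csv`) contains 15,120 samples, each with feature data and a label (the forest cover type). The test set (`test.csv`) contains only feature data, with no labels.\n",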
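    "##### 1.3 Field Descriptions\n",
    "+ Elevation - elevation, in meters\n",
    "+ Aspect - aspect, in degrees azimuth\n",
    "+ Slope - slope, in degrees\n",
    "+ Horizontal_Distance_To_Hydrology - horizontal distance to the nearest surface water\n",
    "+ Vertical_Distance_To_Hydrology - vertical distance to the nearest surface water\n",
    "+ Horizontal_Distance_To_Roadways - horizontal distance to the nearest road\n",
    "+ Hillshade_9am (0 to 255 index) - hillshade index at 9 a.m. on the summer solstice, ranging from 0 to 255\n",
    "+ Hillshade_Noon (0 to 255 index) - hillshade index at noon on the summer solstice, ranging from 0 to 255\n",
    "+ Hillshade_3pm (0 to 255 index) - hillshade index at 3 p.m. on the summer solstice, ranging from 0 to 255\n",
    "+ Horizontal_Distance_To_Fire_Points - horizontal distance to the nearest wildfire ignition point\n",
    "+ Wilderness_Area (4 binary columns, 0 = absence or 1 = presence) - wilderness area designation (0 = absent, 1 = present), covering 4 areas\n",
    "    + The four wilderness areas:\n",
    "      + 1 - Rawah Wilderness Area\n",
    "      + 2 - Neota Wilderness Area\n",
    "      + 3 - Comanche Peak Wilderness Area\n",
    "      + 4 - Cache la Poudre Wilderness Area\n",
    "+ Soil_Type (40 binary columns, 0 = absence or 1 = presence) - soil type (0 = absent, 1 = present), covering 40 types\n",
    "+ Cover_Type (7 types, integers 1 to 7) - forest cover type, i.e. the target variable, with 7 classes numbered 1 to 7\n",
    "\n",
    "In total, the fields above (including the target variable) make up 55 columns: 10 numeric features, 4 wilderness-area indicators, 40 soil-type indicators, and `Cover_Type`. `train.csv` additionally contains an `Id` column, for 56 columns in all."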
   ]
  },
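  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `Wilderness_Area` and `Soil_Type` are one-hot encoded, a row's wilderness area or soil type can be recovered as the index of its indicator column that equals 1. A minimal pure-Python sketch over a dict-like row; `decode_one_hot` is a hypothetical helper, and it assumes (and checks) that exactly one indicator in each group is set:\n",
    "\n",
    "```python\n",
    "def decode_one_hot(row, prefix, n):\n",
    "    # Collect the 1-based indices i for which row[prefix + str(i)] equals 1\n",
    "    active = [i for i in range(1, n + 1) if row[prefix + str(i)] == 1.0]\n",
    "    # Assumption: exactly one indicator per group is set; fail loudly otherwise\n",
    "    if len(active) != 1:\n",
    "        raise ValueError('expected one active %s column, found %d' % (prefix, len(active)))\n",
    "    return active[0]\n",
    "\n",
    "\n",
    "# Toy row with the same indicators as the first training row:\n",
    "# Wilderness_Area1 and Soil_Type29 are set, all others are 0\n",
    "row = {'Wilderness_Area%d' % i: 0.0 for i in range(1, 5)}\n",
    "row.update({'Soil_Type%d' % i: 0.0 for i in range(1, 41)})\n",
    "row['Wilderness_Area1'] = 1.0\n",
    "row['Soil_Type29'] = 1.0\n",
    "\n",
    "print(decode_one_hot(row, 'Wilderness_Area', 4))  # 1\n",
    "print(decode_one_hot(row, 'Soil_Type', 40))  # 29\n",
    "```"
   ]
  },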
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2. Dataset Preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 2.1 Loading the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.context import SparkContext\n",
    "from pyspark.sql.session import SparkSession"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a local SparkContext that uses all available cores, then a SparkSession on top of it\n",
    "sc = SparkContext(\"local[*]\",\"Forest Cover Type\")\n",
    "spark = SparkSession(sc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "covType = sc.textFile(\"data/CoverType/train.csv\")\n",
    "# Split each line on commas\n",
    "data = covType.map(lambda row : row.split(\",\"))\n",
    "# Grab the field names (the header row)\n",
    "header = data.first()\n",
    "# Remove the header row from the data, then convert every value to float\n",
    "data = data.filter(lambda row : row != header). \\\n",
    "            map(lambda row : [float(x) for x in row])"
   ]
  },
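  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The RDD steps above (split each line on commas, drop the header row, convert every value to float) can be mirrored in plain Python on a few in-memory lines, which is a quick way to check the parsing logic without starting Spark. A minimal sketch; `parse_lines` is a hypothetical helper, and the sample lines are the first two training rows cut down to five columns:\n",
    "\n",
    "```python\n",
    "def parse_lines(lines):\n",
    "    # Split each line on commas, like covType.map(lambda row: row.split(','))\n",
    "    rows = [line.split(',') for line in lines]\n",
    "    # The first row holds the field names, like data.first()\n",
    "    header = rows[0]\n",
    "    # Drop the header row, then convert every remaining value to float\n",
    "    data = [[float(x) for x in row] for row in rows if row != header]\n",
    "    return header, data\n",
    "\n",
    "\n",
    "lines = [\n",
    "    'Id,Elevation,Aspect,Slope,Cover_Type',\n",
    "    '1,2596,51,3,5',\n",
    "    '2,2590,56,2,5',\n",
    "]\n",
    "header, data = parse_lines(lines)\n",
    "print(header)  # ['Id', 'Elevation', 'Aspect', 'Slope', 'Cover_Type']\n",
    "print(data[0])  # [1.0, 2596.0, 51.0, 3.0, 5.0]\n",
    "```\n",
    "\n",
    "With the DataFrame API, `spark.read.csv(path, header=True, inferSchema=True)` handles the header and type conversion in one step; it would likely infer integer columns for this file rather than the doubles produced by the explicit `float` conversion."
   ]
  },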
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Id',\n",
       " 'Elevation',\n",
       " 'Aspect',\n",
       " 'Slope',\n",
       " 'Horizontal_Distance_To_Hydrology',\n",
       " 'Vertical_Distance_To_Hydrology',\n",
       " 'Horizontal_Distance_To_Roadways',\n",
       " 'Hillshade_9am',\n",
       " 'Hillshade_Noon',\n",
       " 'Hillshade_3pm',\n",
       " 'Horizontal_Distance_To_Fire_Points',\n",
       " 'Wilderness_Area1',\n",
       " 'Wilderness_Area2',\n",
       " 'Wilderness_Area3',\n",
       " 'Wilderness_Area4',\n",
       " 'Soil_Type1',\n",
       " 'Soil_Type2',\n",
       " 'Soil_Type3',\n",
       " 'Soil_Type4',\n",
       " 'Soil_Type5',\n",
       " 'Soil_Type6',\n",
       " 'Soil_Type7',\n",
       " 'Soil_Type8',\n",
       " 'Soil_Type9',\n",
       " 'Soil_Type10',\n",
       " 'Soil_Type11',\n",
       " 'Soil_Type12',\n",
       " 'Soil_Type13',\n",
       " 'Soil_Type14',\n",
       " 'Soil_Type15',\n",
       " 'Soil_Type16',\n",
       " 'Soil_Type17',\n",
       " 'Soil_Type18',\n",
       " 'Soil_Type19',\n",
       " 'Soil_Type20',\n",
       " 'Soil_Type21',\n",
       " 'Soil_Type22',\n",
       " 'Soil_Type23',\n",
       " 'Soil_Type24',\n",
       " 'Soil_Type25',\n",
       " 'Soil_Type26',\n",
       " 'Soil_Type27',\n",
       " 'Soil_Type28',\n",
       " 'Soil_Type29',\n",
       " 'Soil_Type30',\n",
       " 'Soil_Type31',\n",
       " 'Soil_Type32',\n",
       " 'Soil_Type33',\n",
       " 'Soil_Type34',\n",
       " 'Soil_Type35',\n",
       " 'Soil_Type36',\n",
       " 'Soil_Type37',\n",
       " 'Soil_Type38',\n",
       " 'Soil_Type39',\n",
       " 'Soil_Type40',\n",
       " 'Cover_Type']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# View the field names\n",
    "header"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "15120"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Count the data rows (the header has already been removed)\n",
    "data.count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1.0,\n",
       " 2596.0,\n",
       " 51.0,\n",
       " 3.0,\n",
       " 258.0,\n",
       " 0.0,\n",
       " 510.0,\n",
       " 221.0,\n",
       " 232.0,\n",
       " 148.0,\n",
       " 6279.0,\n",
       " 1.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 1.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " 5.0]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# View the first data row\n",
    "data.first()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 2.2 Creating a DataFrame\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----+---------+------+-----+--------------------------------+------------------------------+-------------------------------+-------------+--------------+-------------+----------------------------------+----------------+----------------+----------------+----------------+----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+----------+\n",
      "|  Id|Elevation|Aspect|Slope|Horizontal_Distance_To_Hydrology|Vertical_Distance_To_Hydrology|Horizontal_Distance_To_Roadways|Hillshade_9am|Hillshade_Noon|Hillshade_3pm|Horizontal_Distance_To_Fire_Points|Wilderness_Area1|Wilderness_Area2|Wilderness_Area3|Wilderness_Area4|Soil_Type1|Soil_Type2|Soil_Type3|Soil_Type4|Soil_Type5|Soil_Type6|Soil_Type7|Soil_Type8|Soil_Type9|Soil_Type10|Soil_Type11|Soil_Type12|Soil_Type13|Soil_Type14|Soil_Type15|Soil_Type16|Soil_Type17|Soil_Type18|Soil_Type19|Soil_Type20|Soil_Type21|Soil_Type22|Soil_Type23|Soil_Type24|Soil_Type25|Soil_Type26|Soil_Type27|Soil_Type28|Soil_Type29|Soil_Type30|Soil_Type31|Soil_Type32|Soil_Type33|Soil_Type34|Soil_Type35|Soil_Type36|Soil_Type37|Soil_Type38|Soil_Type39|Soil_Type40|Cover_Type|\n",
      "+----+---------+------+-----+--------------------------------+------------------------------+-------------------------------+-------------+--------------+-------------+----------------------------------+----------------+----------------+----------------+----------------+----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+----------+\n",
      "| 1.0|   2596.0|  51.0|  3.0|                           258.0|                           0.0|                          510.0|        221.0|         232.0|        148.0|                            6279.0|             1.0|             0.0|             0.0|             0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        1.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|       5.0|\n",
      "| 2.0|   2590.0|  56.0|  2.0|                           212.0|                          -6.0|                          390.0|        220.0|         235.0|        151.0|                            6225.0|             1.0|             0.0|             0.0|             0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        1.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|       5.0|\n",
      "| 3.0|   2804.0| 139.0|  9.0|                           268.0|                          65.0|                         3180.0|        234.0|         238.0|        135.0|                            6121.0|             1.0|             0.0|             0.0|             0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|        0.0|        0.0|        1.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|       2.0|\n",
      "| 4.0|   2785.0| 155.0| 18.0|                           242.0|                         118.0|                         3090.0|        238.0|         238.0|        122.0|                            6211.0|             1.0|             0.0|             0.0|             0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        1.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|       2.0|\n",
      "| 5.0|   2595.0|  45.0|  2.0|                           153.0|                          -1.0|                          391.0|        220.0|         234.0|        150.0|                            6172.0|             1.0|             0.0|             0.0|             0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        1.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|       5.0|\n",
      "| 6.0|   2579.0| 132.0|  6.0|                           300.0|                         -15.0|                           67.0|        230.0|         237.0|        140.0|                            6031.0|             1.0|             0.0|             0.0|             0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        1.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|       2.0|\n",
      "| 7.0|   2606.0|  45.0|  7.0|                           270.0|                           5.0|                          633.0|        222.0|         225.0|        138.0|                            6256.0|             1.0|             0.0|             0.0|             0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        1.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|       5.0|\n",
      "| 8.0|   2605.0|  49.0|  4.0|                           234.0|                           7.0|                          573.0|        222.0|         230.0|        144.0|                            6228.0|             1.0|             0.0|             0.0|             0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        1.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|       5.0|\n",
      "| 9.0|   2617.0|  45.0|  9.0|                           240.0|                          56.0|                          666.0|        223.0|         221.0|        133.0|                            6244.0|             1.0|             0.0|             0.0|             0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        1.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|       5.0|\n",
      "|10.0|   2612.0|  59.0| 10.0|                           247.0|                          11.0|                          636.0|        228.0|         219.0|        124.0|                            6230.0|             1.0|             0.0|             0.0|             0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        1.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|       5.0|\n",
      "|11.0|   2612.0| 201.0|  4.0|                           180.0|                          51.0|                          735.0|        218.0|         243.0|        161.0|                            6222.0|             1.0|             0.0|             0.0|             0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        1.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|       5.0|\n",
      "|12.0|   2886.0| 151.0| 11.0|                           371.0|                          26.0|                         5253.0|        234.0|         240.0|        136.0|                            4051.0|             1.0|             0.0|             0.0|             0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        1.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|       2.0|\n",
      "|13.0|   2742.0| 134.0| 22.0|                           150.0|                          69.0|                         3215.0|        248.0|         224.0|         92.0|                            6091.0|             1.0|             0.0|             0.0|             0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        1.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|       2.0|\n",
      "|14.0|   2609.0| 214.0|  7.0|                           150.0|                          46.0|                          771.0|        213.0|         247.0|        170.0|                            6211.0|             1.0|             0.0|             0.0|             0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        1.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|       5.0|\n",
      "|15.0|   2503.0| 157.0|  4.0|                            67.0|                           4.0|                          674.0|        224.0|         240.0|        151.0|                            5600.0|             1.0|             0.0|             0.0|             0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        1.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|       5.0|\n",
      "|16.0|   2495.0|  51.0|  7.0|                            42.0|                           2.0|                          752.0|        224.0|         225.0|        137.0|                            5576.0|             1.0|             0.0|             0.0|             0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        1.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|       5.0|\n",
      "|17.0|   2610.0| 259.0|  1.0|                           120.0|                          -1.0|                          607.0|        216.0|         239.0|        161.0|                            6096.0|             1.0|             0.0|             0.0|             0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        1.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|       5.0|\n",
      "|18.0|   2517.0|  72.0|  7.0|                            85.0|                           6.0|                          595.0|        228.0|         227.0|        133.0|                            5607.0|             1.0|             0.0|             0.0|             0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        1.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|       5.0|\n",
      "|19.0|   2504.0|   0.0|  4.0|                            95.0|                           5.0|                          691.0|        214.0|         232.0|        156.0|                            5572.0|             1.0|             0.0|             0.0|             0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        1.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|       5.0|\n",
      "|20.0|   2503.0|  38.0|  5.0|                            85.0|                          10.0|                          741.0|        220.0|         228.0|        144.0|                            5555.0|             1.0|             0.0|             0.0|             0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|       0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        1.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|        0.0|       5.0|\n",
      "+----+---------+------+-----+--------------------------------+------------------------------+-------------------------------+-------------+--------------+-------------+----------------------------------+----------------+----------------+----------------+----------------+----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+----------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Build a DataFrame from the RDD, using the header as the column names\n",
    "dataFrame = spark.createDataFrame(data,header)\n",
    "dataFrame.show()"
   ]
  },
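  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`spark.createDataFrame(data, header)` pairs the i-th value of every row with the i-th name in `header` and infers each column's type from the values; since every value here is a Python `float`, each column is inferred as a double, which is why the output shows values such as `2596.0`. The positional pairing can be illustrated in plain Python; `rows_to_records` is a hypothetical helper that also rejects rows whose length differs from the header's:\n",
    "\n",
    "```python\n",
    "def rows_to_records(header, data):\n",
    "    # Pair each value with the field name at the same position\n",
    "    records = []\n",
    "    for i, row in enumerate(data):\n",
    "        # A row with a different number of values would not fit the schema\n",
    "        if len(row) != len(header):\n",
    "            raise ValueError('row %d has %d values, expected %d' % (i, len(row), len(header)))\n",
    "        records.append(dict(zip(header, row)))\n",
    "    return records\n",
    "\n",
    "\n",
    "# Rows with Id 1 and 3, cut down to three columns\n",
    "header = ['Id', 'Elevation', 'Cover_Type']\n",
    "data = [[1.0, 2596.0, 5.0], [3.0, 2804.0, 2.0]]\n",
    "records = rows_to_records(header, data)\n",
    "print(records[1]['Elevation'])  # 2804.0\n",
    "```"
   ]
  },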
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----+---------+------+-----+-------------+----------------+----------+----------+\n",
      "|  Id|Elevation|Aspect|Slope|Hillshade_9am|Wilderness_Area1|Soil_Type1|Cover_Type|\n",
      "+----+---------+------+-----+-------------+----------------+----------+----------+\n",
      "| 1.0|   2596.0|  51.0|  3.0|        221.0|             1.0|       0.0|       5.0|\n",
      "| 2.0|   2590.0|  56.0|  2.0|        220.0|             1.0|       0.0|       5.0|\n",
      "| 3.0|   2804.0| 139.0|  9.0|        234.0|             1.0|       0.0|       2.0|\n",
      "| 4.0|   2785.0| 155.0| 18.0|        238.0|             1.0|       0.0|       2.0|\n",
      "| 5.0|   2595.0|  45.0|  2.0|        220.0|             1.0|       0.0|       5.0|\n",
      "| 6.0|   2579.0| 132.0|  6.0|        230.0|             1.0|       0.0|       2.0|\n",
      "| 7.0|   2606.0|  45.0|  7.0|        222.0|             1.0|       0.0|       5.0|\n",
      "| 8.0|   2605.0|  49.0|  4.0|        222.0|             1.0|       0.0|       5.0|\n",
      "| 9.0|   2617.0|  45.0|  9.0|        223.0|             1.0|       0.0|       5.0|\n",
      "|10.0|   2612.0|  59.0| 10.0|        228.0|             1.0|       0.0|       5.0|\n",
      "|11.0|   2612.0| 201.0|  4.0|        218.0|             1.0|       0.0|       5.0|\n",
      "|12.0|   2886.0| 151.0| 11.0|        234.0|             1.0|       0.0|       2.0|\n",
      "|13.0|   2742.0| 134.0| 22.0|        248.0|             1.0|       0.0|       2.0|\n",
      "|14.0|   2609.0| 214.0|  7.0|        213.0|             1.0|       0.0|       5.0|\n",
      "|15.0|   2503.0| 157.0|  4.0|        224.0|             1.0|       0.0|       5.0|\n",
      "|16.0|   2495.0|  51.0|  7.0|        224.0|             1.0|       0.0|       5.0|\n",
      "|17.0|   2610.0| 259.0|  1.0|        216.0|             1.0|       0.0|       5.0|\n",
      "|18.0|   2517.0|  72.0|  7.0|        228.0|             1.0|       0.0|       5.0|\n",
      "|19.0|   2504.0|   0.0|  4.0|        214.0|             1.0|       0.0|       5.0|\n",
      "|20.0|   2503.0|  38.0|  5.0|        220.0|             1.0|       0.0|       5.0|\n",
      "+----+---------+------+-----+-------------+----------------+----------+----------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# There are many columns, so select a subset of them to view\n",
    "dataFrame.select('Id','Elevation','Aspect', 'Slope','Hillshade_9am',\n",
    "                 'Wilderness_Area1','Soil_Type1',\"Cover_Type\").show()"
   ]
  },
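  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Typing many column names by hand is tedious; the list passed to `select` can instead be built from `dataFrame.columns`. A minimal sketch of the name filtering in plain Python; `columns_with_prefix` is a hypothetical helper, the column list is abbreviated, and the Spark call appears only as a comment:\n",
    "\n",
    "```python\n",
    "def columns_with_prefix(columns, prefix):\n",
    "    # Keep the names that start with prefix, preserving their original order\n",
    "    return [c for c in columns if c.startswith(prefix)]\n",
    "\n",
    "\n",
    "# Abbreviated copy of this dataset's header (only 3 of the 10 numeric fields)\n",
    "columns = (['Id', 'Elevation', 'Aspect', 'Slope']\n",
    "           + ['Wilderness_Area%d' % i for i in range(1, 5)]\n",
    "           + ['Soil_Type%d' % i for i in range(1, 41)]\n",
    "           + ['Cover_Type'])\n",
    "\n",
    "wilderness = columns_with_prefix(columns, 'Wilderness_Area')\n",
    "print(wilderness)  # ['Wilderness_Area1', 'Wilderness_Area2', 'Wilderness_Area3', 'Wilderness_Area4']\n",
    "\n",
    "# With Spark, the names would be unpacked into select, e.g.:\n",
    "# dataFrame.select(*wilderness, 'Cover_Type').show()\n",
    "```"
   ]
  },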
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3. Data Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['__class__',\n",
       " '__delattr__',\n",
       " '__dict__',\n",
       " '__dir__',\n",
       " '__doc__',\n",
       " '__eq__',\n",
       " '__format__',\n",
       " '__ge__',\n",
       " '__getattr__',\n",
       " '__getattribute__',\n",
       " '__getitem__',\n",
       " '__gt__',\n",
       " '__hash__',\n",
       " '__init__',\n",
       " '__init_subclass__',\n",
       " '__le__',\n",
       " '__lt__',\n",
       " '__module__',\n",
       " '__ne__',\n",
       " '__new__',\n",
       " '__reduce__',\n",
       " '__reduce_ex__',\n",
       " '__repr__',\n",
       " '__setattr__',\n",
       " '__sizeof__',\n",
       " '__str__',\n",
       " '__subclasshook__',\n",
       " '__weakref__',\n",
       " '_collectAsArrow',\n",
       " '_jcols',\n",
       " '_jdf',\n",
       " '_jmap',\n",
       " '_jseq',\n",
       " '_lazy_rdd',\n",
       " '_repr_html_',\n",
       " '_sc',\n",
       " '_schema',\n",
       " '_sort_cols',\n",
       " '_support_repr_html',\n",
       " 'agg',\n",
       " 'alias',\n",
       " 'approxQuantile',\n",
       " 'cache',\n",
       " 'checkpoint',\n",
       " 'coalesce',\n",
       " 'colRegex',\n",
       " 'collect',\n",
       " 'columns',\n",
       " 'corr',\n",
       " 'count',\n",
       " 'cov',\n",
       " 'createGlobalTempView',\n",
       " 'createOrReplaceGlobalTempView',\n",
       " 'createOrReplaceTempView',\n",
       " 'createTempView',\n",
       " 'crossJoin',\n",
       " 'crosstab',\n",
       " 'cube',\n",
       " 'describe',\n",
       " 'distinct',\n",
       " 'drop',\n",
       " 'dropDuplicates',\n",
       " 'drop_duplicates',\n",
       " 'dropna',\n",
       " 'dtypes',\n",
       " 'exceptAll',\n",
       " 'explain',\n",
       " 'fillna',\n",
       " 'filter',\n",
       " 'first',\n",
       " 'foreach',\n",
       " 'foreachPartition',\n",
       " 'freqItems',\n",
       " 'groupBy',\n",
       " 'groupby',\n",
       " 'head',\n",
       " 'hint',\n",
       " 'intersect',\n",
       " 'intersectAll',\n",
       " 'isLocal',\n",
       " 'isStreaming',\n",
       " 'is_cached',\n",
       " 'join',\n",
       " 'limit',\n",
       " 'localCheckpoint',\n",
       " 'na',\n",
       " 'orderBy',\n",
       " 'persist',\n",
       " 'printSchema',\n",
       " 'randomSplit',\n",
       " 'rdd',\n",
       " 'registerTempTable',\n",
       " 'repartition',\n",
       " 'repartitionByRange',\n",
       " 'replace',\n",
       " 'rollup',\n",
       " 'sample',\n",
       " 'sampleBy',\n",
       " 'schema',\n",
       " 'select',\n",
       " 'selectExpr',\n",
       " 'show',\n",
       " 'sort',\n",
       " 'sortWithinPartitions',\n",
       " 'sql_ctx',\n",
       " 'stat',\n",
       " 'storageLevel',\n",
       " 'subtract',\n",
       " 'summary',\n",
       " 'take',\n",
       " 'toDF',\n",
       " 'toJSON',\n",
       " 'toLocalIterator',\n",
       " 'toPandas',\n",
       " 'union',\n",
       " 'unionAll',\n",
       " 'unionByName',\n",
       " 'unpersist',\n",
       " 'where',\n",
       " 'withColumn',\n",
       " 'withColumnRenamed',\n",
       " 'withWatermark',\n",
       " 'write',\n",
       " 'writeStream']"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# List the attributes and methods available on a DataFrame\n",
    "dir(dataFrame)"
   ]
  },
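  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`dir()` also lists Python's special methods and PySpark's private helpers (names beginning with `_`, such as `_jdf`). To see only the public API, those names can be filtered out. A minimal sketch on a stand-in class; `public_names` and `ToyFrame` are hypothetical, and with Spark you would pass `dataFrame` instead:\n",
    "\n",
    "```python\n",
    "def public_names(obj):\n",
    "    # dir() returns a sorted list; keep the names that do not start with an underscore\n",
    "    return [name for name in dir(obj) if not name.startswith('_')]\n",
    "\n",
    "\n",
    "class ToyFrame:\n",
    "    # Stand-in for a DataFrame with two public methods and one private helper\n",
    "    def select(self):\n",
    "        pass\n",
    "\n",
    "    def show(self):\n",
    "        pass\n",
    "\n",
    "    def _jdf(self):\n",
    "        pass\n",
    "\n",
    "\n",
    "print(public_names(ToyFrame()))  # ['select', 'show']\n",
    "```"
   ]
  },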
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 3.1 Removing Unneeded Columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Elevation',\n",
       " 'Aspect',\n",
       " 'Slope',\n",
       " 'Horizontal_Distance_To_Hydrology',\n",
       " 'Vertical_Distance_To_Hydrology',\n",
       " 'Horizontal_Distance_To_Roadways',\n",
       " 'Hillshade_9am',\n",
       " 'Hillshade_Noon',\n",
       " 'Hillshade_3pm',\n",
       " 'Horizontal_Distance_To_Fire_Points',\n",
       " 'Wilderness_Area1',\n",
       " 'Wilderness_Area2',\n",
       " 'Wilderness_Area3',\n",
       " 'Wilderness_Area4',\n",
       " 'Soil_Type1',\n",
       " 'Soil_Type2',\n",
       " 'Soil_Type3',\n",
       " 'Soil_Type4',\n",
       " 'Soil_Type5',\n",
       " 'Soil_Type6',\n",
       " 'Soil_Type7',\n",
       " 'Soil_Type8',\n",
       " 'Soil_Type9',\n",
       " 'Soil_Type10',\n",
       " 'Soil_Type11',\n",
       " 'Soil_Type12',\n",
       " 'Soil_Type13',\n",
       " 'Soil_Type14',\n",
       " 'Soil_Type15',\n",
       " 'Soil_Type16',\n",
       " 'Soil_Type17',\n",
       " 'Soil_Type18',\n",
       " 'Soil_Type19',\n",
       " 'Soil_Type20',\n",
       " 'Soil_Type21',\n",
       " 'Soil_Type22',\n",
       " 'Soil_Type23',\n",
       " 'Soil_Type24',\n",
       " 'Soil_Type25',\n",
       " 'Soil_Type26',\n",
       " 'Soil_Type27',\n",
       " 'Soil_Type28',\n",
       " 'Soil_Type29',\n",
       " 'Soil_Type30',\n",
       " 'Soil_Type31',\n",
       " 'Soil_Type32',\n",
       " 'Soil_Type33',\n",
       " 'Soil_Type34',\n",
       " 'Soil_Type35',\n",
       " 'Soil_Type36',\n",
       " 'Soil_Type37',\n",
       " 'Soil_Type38',\n",
       " 'Soil_Type39',\n",
       " 'Soil_Type40',\n",
       " 'Cover_Type']"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Id is only a row identifier with no predictive value, so drop it\n",
    "dataWithoutId = dataFrame.drop(\"Id\")\n",
    "# List the remaining column names\n",
    "dataWithoutId.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- Elevation: double (nullable = true)\n",
      " |-- Aspect: double (nullable = true)\n",
      " |-- Slope: double (nullable = true)\n",
      " |-- Horizontal_Distance_To_Hydrology: double (nullable = true)\n",
      " |-- Vertical_Distance_To_Hydrology: double (nullable = true)\n",
      " |-- Horizontal_Distance_To_Roadways: double (nullable = true)\n",
      " |-- Hillshade_9am: double (nullable = true)\n",
      " |-- Hillshade_Noon: double (nullable = true)\n",
      " |-- Hillshade_3pm: double (nullable = true)\n",
      " |-- Horizontal_Distance_To_Fire_Points: double (nullable = true)\n",
      " |-- Wilderness_Area1: double (nullable = true)\n",
      " |-- Wilderness_Area2: double (nullable = true)\n",
      " |-- Wilderness_Area3: double (nullable = true)\n",
      " |-- Wilderness_Area4: double (nullable = true)\n",
      " |-- Soil_Type1: double (nullable = true)\n",
      " |-- Soil_Type2: double (nullable = true)\n",
      " |-- Soil_Type3: double (nullable = true)\n",
      " |-- Soil_Type4: double (nullable = true)\n",
      " |-- Soil_Type5: double (nullable = true)\n",
      " |-- Soil_Type6: double (nullable = true)\n",
      " |-- Soil_Type7: double (nullable = true)\n",
      " |-- Soil_Type8: double (nullable = true)\n",
      " |-- Soil_Type9: double (nullable = true)\n",
      " |-- Soil_Type10: double (nullable = true)\n",
      " |-- Soil_Type11: double (nullable = true)\n",
      " |-- Soil_Type12: double (nullable = true)\n",
      " |-- Soil_Type13: double (nullable = true)\n",
      " |-- Soil_Type14: double (nullable = true)\n",
      " |-- Soil_Type15: double (nullable = true)\n",
      " |-- Soil_Type16: double (nullable = true)\n",
      " |-- Soil_Type17: double (nullable = true)\n",
      " |-- Soil_Type18: double (nullable = true)\n",
      " |-- Soil_Type19: double (nullable = true)\n",
      " |-- Soil_Type20: double (nullable = true)\n",
      " |-- Soil_Type21: double (nullable = true)\n",
      " |-- Soil_Type22: double (nullable = true)\n",
      " |-- Soil_Type23: double (nullable = true)\n",
      " |-- Soil_Type24: double (nullable = true)\n",
      " |-- Soil_Type25: double (nullable = true)\n",
      " |-- Soil_Type26: double (nullable = true)\n",
      " |-- Soil_Type27: double (nullable = true)\n",
      " |-- Soil_Type28: double (nullable = true)\n",
      " |-- Soil_Type29: double (nullable = true)\n",
      " |-- Soil_Type30: double (nullable = true)\n",
      " |-- Soil_Type31: double (nullable = true)\n",
      " |-- Soil_Type32: double (nullable = true)\n",
      " |-- Soil_Type33: double (nullable = true)\n",
      " |-- Soil_Type34: double (nullable = true)\n",
      " |-- Soil_Type35: double (nullable = true)\n",
      " |-- Soil_Type36: double (nullable = true)\n",
      " |-- Soil_Type37: double (nullable = true)\n",
      " |-- Soil_Type38: double (nullable = true)\n",
      " |-- Soil_Type39: double (nullable = true)\n",
      " |-- Soil_Type40: double (nullable = true)\n",
      " |-- Cover_Type: double (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Print the schema (column names and types)\n",
    "dataWithoutId.printSchema()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As the output above shows, every column is of type double."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
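    "##### 3.2 Transform the Cover_Type target variable\n",
    "    MLlib's decision tree classifier requires labels to start from 0 (valid labels are 0, 1, ..., numClasses - 1), so 1 is subtracted from every Cover_Type value."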
   ]
  },
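  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal plain-Python sketch of this label shift and its inverse. `COVER_TYPE_NAMES`, `to_label` and `label_to_name` are hypothetical helpers (not part of Spark); the names follow the cover-type list in section 1.1. The inverse matters later: the model predicts labels 0 to 6, so 1 has to be added back to recover the original Cover_Type."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helpers illustrating the 1-based -> 0-based label shift (pure Python, no Spark)\n",
    "COVER_TYPE_NAMES = {\n",
    "    1: 'Spruce/Fir', 2: 'Lodgepole Pine', 3: 'Ponderosa Pine',\n",
    "    4: 'Cottonwood/Willow', 5: 'Aspen', 6: 'Douglas-fir', 7: 'Krummholz',\n",
    "}\n",
    "\n",
    "def to_label(cover_type):\n",
    "    # Cover_Type 1..7 -> MLlib label 0.0..6.0\n",
    "    return float(cover_type) - 1.0\n",
    "\n",
    "def label_to_name(label):\n",
    "    # MLlib label or prediction 0.0..6.0 -> cover type name\n",
    "    return COVER_TYPE_NAMES[int(label) + 1]\n",
    "\n",
    "# Usage: Cover_Type 5 becomes label 4.0, and label 4.0 maps back to 'Aspen'\n",
    "print(to_label(5), label_to_name(to_label(5)))"
   ]
  },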
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---------+------+-----+-------------+----------------+----------+----------+--------------+\n",
      "|Elevation|Aspect|Slope|Hillshade_9am|Wilderness_Area1|Soil_Type1|Cover_Type|New_Cover_Type|\n",
      "+---------+------+-----+-------------+----------------+----------+----------+--------------+\n",
      "|   2596.0|  51.0|  3.0|        221.0|             1.0|       0.0|       5.0|           4.0|\n",
      "|   2590.0|  56.0|  2.0|        220.0|             1.0|       0.0|       5.0|           4.0|\n",
      "|   2804.0| 139.0|  9.0|        234.0|             1.0|       0.0|       2.0|           1.0|\n",
      "|   2785.0| 155.0| 18.0|        238.0|             1.0|       0.0|       2.0|           1.0|\n",
      "|   2595.0|  45.0|  2.0|        220.0|             1.0|       0.0|       5.0|           4.0|\n",
      "|   2579.0| 132.0|  6.0|        230.0|             1.0|       0.0|       2.0|           1.0|\n",
      "|   2606.0|  45.0|  7.0|        222.0|             1.0|       0.0|       5.0|           4.0|\n",
      "|   2605.0|  49.0|  4.0|        222.0|             1.0|       0.0|       5.0|           4.0|\n",
      "|   2617.0|  45.0|  9.0|        223.0|             1.0|       0.0|       5.0|           4.0|\n",
      "|   2612.0|  59.0| 10.0|        228.0|             1.0|       0.0|       5.0|           4.0|\n",
      "|   2612.0| 201.0|  4.0|        218.0|             1.0|       0.0|       5.0|           4.0|\n",
      "|   2886.0| 151.0| 11.0|        234.0|             1.0|       0.0|       2.0|           1.0|\n",
      "|   2742.0| 134.0| 22.0|        248.0|             1.0|       0.0|       2.0|           1.0|\n",
      "|   2609.0| 214.0|  7.0|        213.0|             1.0|       0.0|       5.0|           4.0|\n",
      "|   2503.0| 157.0|  4.0|        224.0|             1.0|       0.0|       5.0|           4.0|\n",
      "|   2495.0|  51.0|  7.0|        224.0|             1.0|       0.0|       5.0|           4.0|\n",
      "|   2610.0| 259.0|  1.0|        216.0|             1.0|       0.0|       5.0|           4.0|\n",
      "|   2517.0|  72.0|  7.0|        228.0|             1.0|       0.0|       5.0|           4.0|\n",
      "|   2504.0|   0.0|  4.0|        214.0|             1.0|       0.0|       5.0|           4.0|\n",
      "|   2503.0|  38.0|  5.0|        220.0|             1.0|       0.0|       5.0|           4.0|\n",
      "+---------+------+-----+-------------+----------------+----------+----------+--------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Shift Cover_Type (1 to 7) to a 0-based label (0 to 6) stored in a new column\n",
    "dataWithNewType = dataWithoutId.withColumn(\"New_Cover_Type\", dataWithoutId[\"Cover_Type\"] - 1)\n",
    "dataWithNewType.select('Elevation','Aspect', 'Slope','Hillshade_9am',\n",
    "                 'Wilderness_Area1','Soil_Type1',\"Cover_Type\",\"New_Cover_Type\").show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---------+------+-----+-------------+----------------+----------+--------------+\n",
      "|Elevation|Aspect|Slope|Hillshade_9am|Wilderness_Area1|Soil_Type1|New_Cover_Type|\n",
      "+---------+------+-----+-------------+----------------+----------+--------------+\n",
      "|   2596.0|  51.0|  3.0|        221.0|             1.0|       0.0|           4.0|\n",
      "|   2590.0|  56.0|  2.0|        220.0|             1.0|       0.0|           4.0|\n",
      "|   2804.0| 139.0|  9.0|        234.0|             1.0|       0.0|           1.0|\n",
      "|   2785.0| 155.0| 18.0|        238.0|             1.0|       0.0|           1.0|\n",
      "|   2595.0|  45.0|  2.0|        220.0|             1.0|       0.0|           4.0|\n",
      "|   2579.0| 132.0|  6.0|        230.0|             1.0|       0.0|           1.0|\n",
      "|   2606.0|  45.0|  7.0|        222.0|             1.0|       0.0|           4.0|\n",
      "|   2605.0|  49.0|  4.0|        222.0|             1.0|       0.0|           4.0|\n",
      "|   2617.0|  45.0|  9.0|        223.0|             1.0|       0.0|           4.0|\n",
      "|   2612.0|  59.0| 10.0|        228.0|             1.0|       0.0|           4.0|\n",
      "|   2612.0| 201.0|  4.0|        218.0|             1.0|       0.0|           4.0|\n",
      "|   2886.0| 151.0| 11.0|        234.0|             1.0|       0.0|           1.0|\n",
      "|   2742.0| 134.0| 22.0|        248.0|             1.0|       0.0|           1.0|\n",
      "|   2609.0| 214.0|  7.0|        213.0|             1.0|       0.0|           4.0|\n",
      "|   2503.0| 157.0|  4.0|        224.0|             1.0|       0.0|           4.0|\n",
      "|   2495.0|  51.0|  7.0|        224.0|             1.0|       0.0|           4.0|\n",
      "|   2610.0| 259.0|  1.0|        216.0|             1.0|       0.0|           4.0|\n",
      "|   2517.0|  72.0|  7.0|        228.0|             1.0|       0.0|           4.0|\n",
      "|   2504.0|   0.0|  4.0|        214.0|             1.0|       0.0|           4.0|\n",
      "|   2503.0|  38.0|  5.0|        220.0|             1.0|       0.0|           4.0|\n",
      "+---------+------+-----+-------------+----------------+----------+--------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Drop the original Cover_Type column; New_Cover_Type is now the label\n",
    "dataWithNewType = dataWithNewType.drop(\"Cover_Type\")\n",
    "dataWithNewType.select('Elevation','Aspect', 'Slope','Hillshade_9am',\n",
    "                 'Wilderness_Area1','Soil_Type1',\"New_Cover_Type\").show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As the outputs above show, each New_Cover_Type value is the corresponding Cover_Type value minus 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[4.0, 0.0, 6.0, 2.0, 1.0, 5.0, 3.0]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Collect the distinct label values; New_Cover_Type is the last column, hence row[-1]\n",
    "coverTypeCategory = dataWithNewType.rdd.map(lambda row: row[-1]).distinct().collect()\n",
    "coverTypeCategory"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### 3.3 Split the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
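    "from pyspark.mllib.regression import LabeledPoint\n",
    "# Convert each Row to a LabeledPoint: the last column is the label, the rest are features\n",
    "labelPointRdd = dataWithNewType.rdd.map(lambda r: LabeledPoint(r[-1], r[:-1]))\n",
    "## Split into training and test sets (weights are normalized, so roughly 80% / 20%)\n",
    "(trainData, testData) = labelPointRdd.randomSplit([8, 2])"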
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[LabeledPoint(4.0, [2596.0,51.0,3.0,258.0,0.0,510.0,221.0,232.0,148.0,6279.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainData.take(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 4. Model Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DecisionTreeModel classifier of depth 5 with 51 nodes\n"
     ]
    }
   ],
   "source": [
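    "# Import MLlib's DecisionTree and train a 7-class classifier; {} declares no categorical\n",
    "# features, and the default maxDepth of 5 is used\n",
    "from pyspark.mllib.tree import DecisionTree\n",
    "model = DecisionTree.trainClassifier(trainData, numClasses=7, categoricalFeaturesInfo={})\n",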
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DecisionTreeModel classifier of depth 5 with 51 nodes\n",
      "  If (feature 0 <= 2670.5)\n",
      "   If (feature 0 <= 2373.5)\n",
      "    If (feature 6 <= 195.5)\n",
      "     If (feature 3 <= 15.0)\n",
      "      Predict: 3.0\n",
      "     Else (feature 3 > 15.0)\n",
      "      If (feature 30 <= 0.5)\n",
      "       Predict: 2.0\n",
      "      Else (feature 30 > 0.5)\n",
      "       Predict: 3.0\n",
      "    Else (feature 6 > 195.5)\n",
      "     If (feature 3 <= 15.0)\n",
      "      If (feature 12 <= 0.5)\n",
      "       Predict: 3.0\n",
      "      Else (feature 12 > 0.5)\n",
      "       Predict: 5.0\n",
      "     Else (feature 3 > 15.0)\n",
      "      Predict: 3.0\n",
      "   Else (feature 0 > 2373.5)\n",
      "    If (feature 17 <= 0.5)\n",
      "     If (feature 10 <= 0.5)\n",
      "      Predict: 5.0\n",
      "     Else (feature 10 > 0.5)\n",
      "      If (feature 0 <= 2513.5)\n",
      "       Predict: 4.0\n",
      "      Else (feature 0 > 2513.5)\n",
      "       Predict: 1.0\n",
      "    Else (feature 17 > 0.5)\n",
      "     If (feature 1 <= 95.5)\n",
      "      If (feature 3 <= 284.0)\n",
      "       Predict: 5.0\n",
      "      Else (feature 3 > 284.0)\n",
      "       Predict: 2.0\n",
      "     Else (feature 1 > 95.5)\n",
      "      If (feature 3 <= 15.0)\n",
      "       Predict: 5.0\n",
      "      Else (feature 3 > 15.0)\n",
      "       Predict: 2.0\n",
      "  Else (feature 0 > 2670.5)\n",
      "   If (feature 0 <= 3209.5)\n",
      "    If (feature 0 <= 2933.5)\n",
      "     If (feature 9 <= 2597.5)\n",
      "      Predict: 4.0\n",
      "     Else (feature 9 > 2597.5)\n",
      "      If (feature 2 <= 22.5)\n",
      "       Predict: 1.0\n",
      "      Else (feature 2 > 22.5)\n",
      "       Predict: 4.0\n",
      "    Else (feature 0 > 2933.5)\n",
      "     If (feature 0 <= 3039.5)\n",
      "      If (feature 3 <= 76.0)\n",
      "       Predict: 0.0\n",
      "      Else (feature 3 > 76.0)\n",
      "       Predict: 1.0\n",
      "     Else (feature 0 > 3039.5)\n",
      "      If (feature 7 <= 239.5)\n",
      "       Predict: 0.0\n",
      "      Else (feature 7 > 239.5)\n",
      "       Predict: 1.0\n",
      "   Else (feature 0 > 3209.5)\n",
      "    If (feature 0 <= 3293.5)\n",
      "     If (feature 52 <= 0.5)\n",
      "      If (feature 51 <= 0.5)\n",
      "       Predict: 0.0\n",
      "      Else (feature 51 > 0.5)\n",
      "       Predict: 6.0\n",
      "     Else (feature 52 > 0.5)\n",
      "      Predict: 6.0\n",
      "    Else (feature 0 > 3293.5)\n",
      "     If (feature 45 <= 0.5)\n",
      "      Predict: 6.0\n",
      "     Else (feature 45 > 0.5)\n",
      "      If (feature 2 <= 7.5)\n",
      "       Predict: 6.0\n",
      "      Else (feature 2 > 7.5)\n",
      "       Predict: 0.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Print the full tree structure\n",
    "print(model.toDebugString())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
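    "As the output above shows, the trained decision tree has depth 5 and 51 nodes in total: 25 internal (split) nodes and 26 leaf nodes."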
   ]
  },
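  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The node counts can be cross-checked from the printed structure. The sketch below uses a hypothetical helper, `count_tree_nodes` (not part of MLlib), and assumes the `toDebugString()` layout shown above: one `If (...)` line per internal node and one `Predict:` line per leaf. Counted this way, the tree printed above has 25 internal nodes and 26 leaves. `DecisionTreeModel` also provides `depth()` and `numNodes()` for the totals."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_tree_nodes(debug_string):\n",
    "    # Hypothetical helper: parses the text returned by DecisionTreeModel.toDebugString()\n",
    "    lines = [line.strip() for line in debug_string.splitlines()]\n",
    "    # Each split prints one 'If (' line; its 'Else (' branch is not a separate node\n",
    "    internal = sum(1 for line in lines if line.startswith('If ('))\n",
    "    # Each leaf prints one 'Predict:' line\n",
    "    leaves = sum(1 for line in lines if line.startswith('Predict:'))\n",
    "    return internal, leaves\n",
    "\n",
    "# Usage with the model trained above (counts change if the random split is redrawn):\n",
    "# internal, leaves = count_tree_nodes(model.toDebugString())\n",
    "# print(model.depth(), model.numNodes(), internal, leaves)"
   ]
  },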
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 5. Model Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(1.0, 4.0), (1.0, 4.0), (1.0, 4.0), (4.0, 4.0), (1.0, 4.0)]"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
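    "# Evaluate the model with AUC (area under the ROC curve)\n",
    "## Predict on the test-set features and pair each prediction with its true label as (prediction, label)\n",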
    "predict = model.predict(testData.map(lambda p:p.features))\n",
    "predict_real = predict.zip(testData.map(lambda p: p.label))\n",
    "predict_real.take(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AUC=0.812702820246897\n"
     ]
    }
   ],
   "source": [
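    "# Then compute AUC with BinaryClassificationMetrics. Caveat: this class expects binary 0/1 labels,\n",
    "# so with 7 classes the resulting AUC is not a meaningful multiclass metric\n",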
    "from pyspark.mllib.evaluation import BinaryClassificationMetrics\n",
    "metrics = BinaryClassificationMetrics(predict_real)\n",
    "print(\"AUC=\"+str(metrics.areaUnderROC))"
   ]
  },
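  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`BinaryClassificationMetrics` is built for binary 0/1 labels with a score per sample, so the AUC above does not summarize this 7-class problem well. Accuracy and a confusion matrix fit better. The sketch below computes both in plain Python from (prediction, label) pairs; `multiclass_summary` is a hypothetical helper. The commented lines show MLlib's `MulticlassMetrics`, which accepts the same `predict_real` RDD of (prediction, label) pairs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "def multiclass_summary(prediction_label_pairs):\n",
    "    # Hypothetical helper: accuracy and confusion counts from (prediction, label) pairs\n",
    "    pairs = list(prediction_label_pairs)\n",
    "    correct = sum(1 for prediction, label in pairs if prediction == label)\n",
    "    accuracy = correct / len(pairs) if pairs else 0.0\n",
    "    # Keys are (true label, predicted label); values count how often each combination occurs\n",
    "    confusion = Counter((label, prediction) for prediction, label in pairs)\n",
    "    return accuracy, confusion\n",
    "\n",
    "# Usage on the five pairs printed above: only one of the five predictions is correct\n",
    "accuracy, confusion = multiclass_summary([(1.0, 4.0), (1.0, 4.0), (1.0, 4.0), (4.0, 4.0), (1.0, 4.0)])\n",
    "print(accuracy, dict(confusion))\n",
    "\n",
    "# Spark-side equivalent on the full test split:\n",
    "# from pyspark.mllib.evaluation import MulticlassMetrics\n",
    "# multiMetrics = MulticlassMetrics(predict_real)\n",
    "# print('Accuracy =', multiMetrics.accuracy)\n",
    "# print(multiMetrics.confusionMatrix().toArray())"
   ]
  },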
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 6. Predicting with the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample Id: 15174.0, prediction: Lodgepole Pine\n",
      "Sample Id: 15208.0, prediction: Lodgepole Pine\n",
      "Sample Id: 15270.0, prediction: Lodgepole Pine\n",
      "Sample Id: 15320.0, prediction: Aspen\n",
      "Sample Id: 15436.0, prediction: Lodgepole Pine\n",
      "Sample Id: 15767.0, prediction: Lodgepole Pine\n",
      "Sample Id: 15830.0, prediction: Lodgepole Pine\n",
      "Sample Id: 15949.0, prediction: Lodgepole Pine\n",
      "Sample Id: 15971.0, prediction: Lodgepole Pine\n",
      "Sample Id: 16102.0, prediction: Lodgepole Pine\n",
      "Sample Id: 16219.0, prediction: Lodgepole Pine\n",
      "Sample Id: 16256.0, prediction: Aspen\n",
      "Sample Id: 16301.0, prediction: Lodgepole Pine\n",
      "Sample Id: 16651.0, prediction: Lodgepole Pine\n",
      "Sample Id: 16756.0, prediction: Lodgepole Pine\n",
      "Sample Id: 16886.0, prediction: Aspen\n",
      "Sample Id: 16918.0, prediction: Lodgepole Pine\n",
      "Sample Id: 16935.0, prediction: Lodgepole Pine\n",
      "Sample Id: 16946.0, prediction: Lodgepole Pine\n",
      "Sample Id: 16968.0, prediction: Aspen\n"
     ]
    }
   ],
   "source": [
    "# Named predictTestSet so it does not overwrite the `predict` RDD from the evaluation step\n",
    "def predictTestSet(sc, model):\n",
    "    # Load the test set\n",
    "    testSet = sc.textFile(\"data/CoverType/test.csv\")\n",
    "    firstLine = testSet.first()\n",
    "    # Preprocess: drop the header row and parse every field as float. test.csv has no\n",
    "    # Cover_Type column, so the Id (first column) is kept in the LabeledPoint label slot\n",
    "    testSet = testSet.filter(lambda row : row != firstLine). \\\n",
    "                      map(lambda row : row.split(\",\")). \\\n",
    "                      map(lambda row : [float(x) for x in row]). \\\n",
    "                      map(lambda row : LabeledPoint(row[0],row[1:]))\n",
    "    # Map the 0-based label (Cover_Type - 1) to the cover type name\n",
    "    DescDict={0:\"Spruce/Fir\",\n",
    "              1:\"Lodgepole Pine\",\n",
    "              2:\"Ponderosa Pine\",\n",
    "              3:\"Cottonwood/Willow\",\n",
    "              4:\"Aspen\",\n",
    "              5:\"Douglas-fir\",\n",
    "              6:\"Krummholz\"}\n",
    "    # Sample about 1% of the rows (without replacement, seed 111) and show the first 20\n",
    "    for sample in testSet.sample(False,0.01,111).take(20):\n",
    "        dataId = sample.label\n",
    "        prediction = model.predict(sample.features)\n",
    "        coverType = DescDict[prediction]\n",
    "        print(\"Sample Id: {}, prediction: {}\".format(dataId,coverType))\n",
    "predictTestSet(sc,model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
